AITER Native Padding Support and BSHD + Padding --> THD + Padding conversion #354
base: dev
Conversation
[ROCm] support v2 bwd native padding
wangye805 left a comment:
Generally, I think we can try to remove all memsets except for dq and dq_acc. We can confirm with the aiter/CK people.
```diff
 dk_or_dv_reduce_thd<CK_TILE_TYPE>, grid, block_dk, 0, stream,
-h, hg, d_qk,
+static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
+b, h, hg, d_qk,
```
Do we have GQA/MQA + MLA test cases with and without padding? If not, can we create those to verify this flow is actually working?
Will work on adding one on the JAX side. For now I've added one on the TE side that isn't able to run because too few backends support it, but that may change, e.g. as we update AOTriton.
Thanks. Then let's skip the PyTorch-side GQA/MQA + MLA test for now. You can put a to-do here and add it later when other backends support it.
Let's also add how to use the runtime segment/max seqlen support to the readme under https://github.com/ROCm/TransformerEngine?tab=readme-ov-file#fused-attention-backends-on-rocm. Remind our customers that this will break CUDA graphs.
@wangye805 I've now updated the readme, but let me know if you have specific thoughts on it.
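For context on the readme note, here is a minimal sketch of the kind of environment-variable gate being discussed. Only the variable name `NVTE_CK_RUNTIME_MAX_SEQLEN` comes from this PR's change list; the parsing and placement below are assumptions, not the PR's actual implementation.

```cpp
#include <cstdlib>
#include <cstring>

// Sketch only: gate the runtime max_seqlen computation behind the env var
// introduced in this PR. The real code may parse the value differently.
static bool use_runtime_max_seqlen() {
  const char* v = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN");
  return v != nullptr && std::strcmp(v, "0") != 0;
}
```

Presumably because the runtime path derives the value from the sequence-length tensors rather than from static shapes, the readme warns that enabling it breaks CUDA graph capture.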
wangye805 left a comment:
Please take a look at the several previously unresolved conversations.
- Updated debug message for BSHD-->THD conversion
- Added env variable to gate FWD output memset for padding
- Removed guards on memsets for d{Q,K,V} matrices
@Micky774 Could you rebase/merge the latest dev to incorporate the hot fixes for the sgpu tests?
PyTorch test_numerics also shows some fused-attn-related failures. Not sure whether this is related to our decision to remove the memsets.
Those failures were due to a mix of not correctly dispatching to the …
wangye805 left a comment:
For those newly added hybrid QKV formats in upstream (NVTE_SBHD_2BSHD, NVTE_BSHD_2SBHD, NVTE_THD_2BSHD, and NVTE_THD_2SBHD): in addition to the SBHD_2BSHD pytest failures, are we able to correctly handle the other three? Or are there only SBHD_2BSHD pytests now?
NV upstream separates the format and is_ragged flags for q/kv and does the subsequent processing accordingly:
TransformerEngine/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu (lines 79 to 82 in 32e2d1d):

```cpp
NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);
```
Maybe we can try a similar technique. If I recall correctly, we need padding/unpadding for just q in SBHD_2BSHD and for just k/v in BSHD_2SBHD.
Or it's okay if you want to leave this for another PR.
By the way, there is an "extra line" comment you may have ignored :-)
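As an illustration of the per-tensor dispatch suggested above, here is a minimal sketch that mirrors the upstream snippet quoted earlier; `handle_ragged_q()`, `handle_ragged_kv()`, and `dispatch_padding()` are hypothetical placeholders, not functions in this PR.

```cpp
#include <transformer_engine/fused_attn.h>  // NVTE_QKV_Layout and the format helpers used above

// Hypothetical helpers standing in for whatever pad/unpad routine the PR uses:
void handle_ragged_q();
void handle_ragged_kv();

void dispatch_padding(NVTE_QKV_Layout layout) {
  NVTE_QKV_Format q_format = nvte_get_q_format(layout);
  NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
  bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
  bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);

  if (is_ragged_q) handle_ragged_q();    // pad/unpad q independently of k/v
  if (is_ragged_kv) handle_ragged_kv();  // pad/unpad k/v independently of q
}
```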
```diff
 max_tokens_q,
 devPtrQWithoutPadding,
-q_stride[1], q_stride[0],
+q_stride[1], (is_ragged? q_stride[2] : std::min(q_stride[0], q_stride[2])),
```
Do we need std::min(q_stride[0], q_stride[2]) for SBHD_2BSHD formats with padding?
```diff
 max_tokens_q, max_tokens_kv,
 devPtrQWithoutPadding,
-q_stride[1], q_stride[0],
+q_stride[1], (is_ragged? q_stride[2] : std::min(q_stride[0], q_stride[2])),
```
Similar question here for SBHD_2BSHD "hybrid-style" format
```diff
 NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQ, 0, q_storage_bytes, stream));
 // for pad between seqs case, we need to reset all dq, dk, dv
-if(pad_between_seqs){
+if(is_padding && nvte_ck_zero_out_pad){
```
Based on Wen's analysis, dq, dk, and dv require zeroing out the padding locations for subsequent grad computation, so they should be memset unconditionally (nvte_ck_zero_out_pad should only control the O tensor).
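A minimal sketch of the policy being requested here, extending the names visible in the hunk above; `devPtrdK`, `devPtrdV`, and the k/v byte counts are illustrative assumptions, not identifiers taken from this PR.

```cpp
// Zero dq/dk/dv unconditionally: padded locations in these gradient buffers
// feed into subsequent grad computation.
NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQ, 0, q_storage_bytes, stream));
NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdK, 0, k_storage_bytes, stream));
NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdV, 0, v_storage_bytes, stream));

// The nvte_ck_zero_out_pad gate would then apply only to the forward O tensor,
// not to these gradient buffers.
```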
Description
Feature update PR that includes several iterative changes for client-driven optimization targets. It contains both API changes for CK/AITER and changes to the internal integration; see the list of changes for specifics.
Note that this will not be ready to merge until ROCm/aiter#1212 is merged and this PR's AITER commit is updated.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- `max_seqlen` calculation gated by new env var `NVTE_CK_RUNTIME_MAX_SEQLEN`
- `v3_api_check` support (temporary)
- `pad_between_seqs` (need to follow up with a PR cleaning up the test suite for old `pad_between_seqs` edge cases)
- `NVTE_CK_RUNTIME_NUM_SEGMENTS` to guard runtime calculation of the number of segments in the JAX integration

Checklist: